Update distill Megatron plugin #319
Conversation
Signed-off-by: Asha Anoosheh <[email protected]>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Walkthrough

The distillation API was updated to pass arbitrary kwargs to per-layer loss functions via compute_kd_loss. In the Megatron plugin, a new DistillationConfig dataclass replaces dict configs, loss classes were refactored to a model_config-based API, pipeline-parallel handling was added, and loss balancing now returns structured components.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant Trainer
    participant MegatronPlugin as Megatron Plugin
    participant DistillCfg as DistillationConfig
    participant DistillModel as DistillationModel
    participant Student
    participant Teacher
    participant Losses as LossFns/Balancer

    User->>Trainer: start training
    Trainer->>MegatronPlugin: load_distillation_config(path, student_cfg, teacher_cfg)
    MegatronPlugin-->>Trainer: DistillationConfig
    Trainer->>MegatronPlugin: adjust_distillation_model_for_mcore(DistillationModel, DistillationConfig)
    MegatronPlugin->>DistillModel: patch for PP, hide teacher, LM-loss bypass
    Note over DistillModel: Pipeline-aware forward hooks installed

    loop each batch
        Trainer->>DistillModel: forward(inputs)
        activate DistillModel
        DistillModel->>Student: forward(student_inputs)
        DistillModel->>Teacher: forward(teacher_inputs)
        Teacher-->>DistillModel: teacher outputs
        Student-->>DistillModel: student outputs
        DistillModel-->>Trainer: concatenated or student-only outputs (PP-aware)
        deactivate DistillModel
        Trainer->>DistillModel: compute_kd_loss(**loss_fn_kwargs)
        DistillModel->>Losses: per-layer losses(out_s, out_t, **kwargs)
        Losses-->>DistillModel: logits/intermediate losses
        DistillModel->>Losses: balance(loss_dict, scales, skip_balancer?)
        Losses-->>DistillModel: {kd_loss, logits_loss, intermediate_loss}
        DistillModel-->>Trainer: kd_loss or dict
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #319      +/-   ##
==========================================
+ Coverage   73.87%   73.88%   +0.01%
==========================================
  Files         172      172
  Lines       17439    17443       +4
==========================================
+ Hits        12883    12888       +5
+ Misses       4556     4555       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/distill/distillation_model.py (1)
Lines 271-276: Potential incompatibility: loss_fn must return a Tensor.
Several new Megatron losses return a tuple (loss, tp_reduce, is_sequence_parallel). Passing that through here will break balancing and masking.
Two safe options:
- A) Keep loss fns returning Tensor only (preferred for compatibility).
- B) Teach compute_kd_loss to accept tuples and normalize to a Tensor before reduction/balancing.
If you choose (A), see my Megatron comments to revert post_forward to return only a Tensor.
If (B), I can draft a minimal normalizer that handles TP/SP flags; a rough sketch follows below. Want me to?
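For reference, a normalizer for option (B) could look roughly like the sketch below; the tuple layout (loss, tp_reduce, is_sequence_parallel) comes from the diff in this PR, while the process-group argument and the averaging choice are assumptions rather than the plugin's actual handling:

```python
import torch
import torch.distributed as dist
from torch import Tensor


def normalize_loss_output(out, tp_group: "dist.ProcessGroup | None" = None) -> Tensor:
    """Collapse a (loss, tp_reduce, is_sequence_parallel) tuple into a plain Tensor."""
    if isinstance(out, Tensor):
        return out
    loss, tp_reduce, is_sequence_parallel = out
    if (tp_reduce or is_sequence_parallel) and tp_group is not None and dist.is_initialized():
        # Assumption: averaging the partial loss across the tensor-parallel group is the
        # desired reduction; the real semantics depend on how Megatron shards the loss.
        dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=tp_group)
        loss = loss / dist.get_world_size(group=tp_group)
    return loss
```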
modelopt/torch/distill/plugins/megatron.py (1)
Lines 372-421: Loss balancer returns a dict and assumes original loss exists; both will break training.
Issues:
- forward must return a scalar Tensor per DistillationLossBalancer contract.
- Unconditional pop of student loss raises KeyError when not provided or when skip_original_loss is True.
- Summing Tensors via Python sum with empty start causes type errors.
- Comparing Tensors with “> 0” is ambiguous if not 0‑dim.
Apply robust, contract‑preserving changes:
```diff
-    def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
+    def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
         """Forward function.
@@
-        original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY)
-        for _key in loss_dict:
-            if _key.startswith(LogitsKLLoss.__name__):
-                logits_key = _key  # should only be one
-        logits_loss = loss_dict.pop(logits_key)
-        intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1)
+        # Work on a copy to avoid mutating caller state.
+        loss_dict = dict(loss_dict)
+        original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY, None)
+        # Extract logits loss
+        logits_keys = [k for k in loss_dict if k.startswith(LogitsKLLoss.__name__)]
+        if len(logits_keys) != 1:
+            raise ValueError(f"Expected exactly one logits loss, found: {logits_keys}")
+        logits_loss = loss_dict.pop(logits_keys[0])
+        # Normalize to scalars
+        def _to_scalar(t: Tensor) -> Tensor:
+            return t.mean() if t.dim() > 0 else t
+        logits_loss = _to_scalar(logits_loss)
+        interm_values = list(loss_dict.values())
+        if interm_values:
+            intermediate_loss = _to_scalar(torch.stack([_to_scalar(v) for v in interm_values]).mean())
+        else:
+            intermediate_loss = torch.zeros_like(logits_loss)
@@
-        if intermediate_loss > 0:
-            dynamic_scale = logits_loss.item() / intermediate_loss.item()
-            intermediate_loss_scaled = intermediate_loss * dynamic_scale
-        else:
-            intermediate_loss = logits_loss.new_tensor(intermediate_loss)
-            intermediate_loss_scaled = intermediate_loss
+        if intermediate_loss.detach().abs().item() > 0:
+            dynamic_scale = logits_loss.detach().item() / intermediate_loss.detach().item()
+            intermediate_loss_scaled = intermediate_loss * dynamic_scale
+        else:
+            intermediate_loss_scaled = torch.zeros_like(logits_loss)
@@
-        if self._skip_original_loss:
-            total_loss = logits_loss + intermediate_loss_scaled
-        else:
-            kd_loss = logits_loss + intermediate_loss_scaled
-            kd_loss *= original_loss.item() / kd_loss.item()
-            total_loss = original_loss + kd_loss * self._kd_loss_scale
+        if self._skip_original_loss or original_loss is None:
+            total_loss = logits_loss + intermediate_loss_scaled
+        else:
+            kd_loss = logits_loss + intermediate_loss_scaled
+            kd_loss = kd_loss * (original_loss.detach().item() / max(kd_loss.detach().item(), 1e-12))
+            total_loss = original_loss + kd_loss * self._kd_loss_scale
@@
-        out_dict = {
-            "kd_loss": total_loss,
-            "logits_loss": logits_loss,
-            "intermediate_loss": intermediate_loss,
-        }
-        return out_dict
+        # Optional: expose components for logging
+        self.last_components = {
+            "total_loss": total_loss.detach(),
+            "logits_loss": logits_loss.detach(),
+            "intermediate_loss": intermediate_loss.detach(),
+        }
+        return total_loss
```
🧹 Nitpick comments (4)
modelopt/torch/distill/distillation_model.py (2)
Lines 251-257: Docstring still references ``reduce=True/False``.
The public API uses skip_balancer, not reduce. Update to avoid confusion.
Apply:
```diff
-        Returns:
-            If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
-            If reduce is False, a dict of student model output loss and layer-wise distillation losses.
+        Returns:
+            If ``skip_balancer`` is False, the total loss as returned by the configured loss balancer
+            (typically a scalar Tensor). If ``skip_balancer`` is True, a dict of student loss and
+            layer-wise distillation losses.
```
Lines 123-127: Type hint nit: teacher_model is not a ModuleList.
The return annotation should be nn.Module.
Apply:
```diff
-    def teacher_model(self) -> nn.ModuleList:
+    def teacher_model(self) -> nn.Module:
```

modelopt/torch/distill/plugins/megatron.py (2)
Lines 50-69: DistillationConfig: make internal fields non-init to prevent YAML collisions.
criterion and loss_balancer are derived; disallow user-provided values to avoid accidental overrides.
```diff
-    criterion: Criterion | None = None
-    loss_balancer: mtd.DistillationLossBalancer | None = None
+    criterion: Criterion | None = field(default=None, init=False, repr=False)
+    loss_balancer: mtd.DistillationLossBalancer | None = field(default=None, init=False, repr=False)
```
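For context, the `field(init=False)` pattern keeps derived attributes out of the generated constructor, so YAML-driven kwargs cannot collide with them; a minimal generic example (not the plugin's actual dataclass):

```python
from dataclasses import dataclass, field


@dataclass
class ConfigSketch:
    kd_loss_scale: float = 1.0  # user-facing, settable from YAML kwargs
    criterion: object | None = field(default=None, init=False, repr=False)  # derived at runtime


cfg = ConfigSketch(kd_loss_scale=2.0)
# ConfigSketch(criterion=...) would raise TypeError because criterion is init=False.
```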
Lines 553-616: Tensor shapes adjuster: make it safe against empty shapes.
Guard for empty recv/send shapes to avoid index errors in corner scheduling cases.
```diff
-    def adjust_tensor_shapes(
+    def adjust_tensor_shapes(
         recv_tensor_shapes: list[tuple[int, ...]], send_tensor_shapes: list[tuple[int, ...]]
     ):
+        if not recv_tensor_shapes or not send_tensor_shapes:
+            return recv_tensor_shapes, send_tensor_shapes
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/distill/distillation_model.py (3 hunks)
- modelopt/torch/distill/plugins/megatron.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/distill/plugins/megatron.py (2)
modelopt/torch/distill/distillation_model.py (5)
- loss_balancer (134-136)
- DistillationModel (37-288)
- hide_teacher_model (139-147)
- teacher_model (124-126)
- only_student_forward (171-178)

modelopt/torch/distill/loss_balancers.py (1)
- DistillationLossBalancer (33-71)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: wait-checks / wait
🔇 Additional comments (2)
modelopt/torch/distill/plugins/megatron.py (1)
Lines 488-551: Guard set_input_tensor calls or ensure the teacher implements it.
A repo search found no `def set_input_tensor` on the teacher; the call `type(self).set_input_tensor(self.teacher_model, ...)` in modelopt/torch/distill/plugins/megatron.py (`_set_input_tensor`) will raise if it is missing. Add a `hasattr` guard and a fallback (e.g. call `self.teacher_model.set_input_tensor(...)` if present) or raise a clear error.
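A guard along the lines suggested above might look like this; the helper name and exact signature are hypothetical, since the actual `_set_input_tensor` patch in the plugin handles more than the teacher call:

```python
from torch import nn


def route_input_tensor_to_teacher(teacher: nn.Module, input_tensor) -> None:
    """Hypothetical helper: forward pipeline-parallel input tensors to the teacher safely."""
    set_fn = getattr(teacher, "set_input_tensor", None)
    if callable(set_fn):
        set_fn(input_tensor)
    else:
        raise AttributeError(
            f"{type(teacher).__name__} does not implement set_input_tensor; "
            "cannot route pipeline-parallel input tensors to the teacher model."
        )
```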
modelopt/torch/distill/distillation_model.py (1)
Lines 242-244: Approve: kwargs passthrough is OK; no call sites pass positional extras.
Scanned repo for compute_kd_loss usages; all calls either pass keyword args or no args: tests/unit/torch/distill/test_distill.py (lines 94, 111, 127, 188); examples/chained_optimizations/bert_prune_distill_quantize.py:1035; modelopt/torch/distill/plugins/huggingface.py:96; modelopt/torch/quantization/plugins/transformers_trainer.py:402.
```diff
     def post_forward(
         self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
     ) -> Tensor:
         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
         loss = loss.transpose(0, 1).contiguous()
-        return (loss, tp_reduce)
+        return (loss, tp_reduce, is_sequence_parallel)
```
Breaking change: BaseLoss.post_forward now returns a tuple.
This violates the expected Loss API (forward → Tensor) and breaks DistillationModel.compute_kd_loss which assumes Tensor. Revert to returning a Tensor; encode TP/SP concerns elsewhere.
```diff
-    def post_forward(
-        self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False
-    ) -> Tensor:
+    def post_forward(self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False) -> Tensor:
         """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
         loss = loss.transpose(0, 1).contiguous()
-        return (loss, tp_reduce, is_sequence_parallel)
+        return loss
```
And adjust callers (see comments below).
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test and benchmark the code to ensure it meets the requirements.

```python
    def post_forward(self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False) -> Tensor:
        """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
        loss = loss.transpose(0, 1).contiguous()
        return loss
```
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 181 to 187,
post_forward currently returns a tuple (loss, tp_reduce, is_sequence_parallel)
which breaks the Loss API that expects a Tensor; revert post_forward to return
only the Tensor (i.e., transpose and contiguous as before) and remove the tuple
packaging; move TP/SP flags out of this return (e.g., set attributes on the
plugin instance, provide accessor methods, or pass flags via the distillation
caller/compute_kd_loss invocation) and update any callers (notably
DistillationModel.compute_kd_loss and other code expecting a Tensor) to retrieve
TP/SP information from the new location rather than from post_forward's return
value.
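One way to follow this prompt, sketched under the assumption that each loss instance may carry its own reduction flags (the attribute names are illustrative, not the plugin's current fields):

```python
import torch
from torch import Tensor, nn


class BaseLossSketch(nn.Module):
    """Illustrative base loss that records TP/SP flags as attributes instead of returning a tuple."""

    def __init__(self) -> None:
        super().__init__()
        self.tp_reduce: bool = False
        self.is_sequence_parallel: bool = False

    def post_forward(self, loss: Tensor, tp_reduce: bool = False, is_sequence_parallel: bool = False) -> Tensor:
        # Remember how the caller should reduce this loss, but keep the return type a Tensor.
        self.tp_reduce = tp_reduce
        self.is_sequence_parallel = is_sequence_parallel
        return loss.transpose(0, 1).contiguous()


loss_fn = BaseLossSketch()
out = loss_fn.post_forward(torch.randn(6, 2), tp_reduce=True)
assert isinstance(out, Tensor) and loss_fn.tp_reduce
```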
```diff
     def __init__(
+        self, model_config: "TransformerConfig", projection_layer: nn.Module | None = None
     ):
         """Constructor.
         Args:
-            student_config: Student's MCore transformer config.
-            teacher_config: Teacher's MCore transformer config.
+            model_config: MCore transformer config.
             projection_layer: Module which projects student activations to teacher's hidden dim.
         """
-        super().__init__(student_config, teacher_config, projection_layer=True)
+        super().__init__(model_config, projection_layer=projection_layer)

-        if self._tensor_parallel and not self._sequence_parallel:
+        if self._config.tensor_model_parallel_size > 1:
             logger.warning(
                 "``HiddenStateCosineLoss`` only works with tensors with full hidden dim. Ensure the "
-                "tensor inputs meet this requirement or use `--sequence_parallel` if tensor parallel is enabled."
+                "tensor inputs meet this requirement. We recommend only applying this loss to LayerNorm outputs, "
+                "which have full hidden dim even when TP is used."
             )
```
Align HiddenStateCosineLoss with Tensor‑only contract.
Return a Tensor from forward; remove tuple propagation.
```diff
-        return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel)
+        return self.post_forward(loss)
```
Also applies to: 255-257
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 217-233 (and also
apply same change at 255-257), the HiddenStateCosineLoss implementation
currently propagates a tuple output but the contract requires returning a single
Tensor from forward; update the forward method to return only the loss Tensor
(not a tuple) and remove any tuple wrapping or propagation in __init__ or helper
methods so all call sites receive a Tensor; ensure type hints and docstring
reflect Tensor return and update any downstream unpacking to accept a single
Tensor.
```diff
     def __init__(
-        self,
-        student_config: TransformerConfig,
-        teacher_config: TransformerConfig,
-        temperature: float = 1.0,
-        reverse: bool = False,
+        self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False
     ):
         """Constructor.
         Args:
-            student_config: Student's MCore transformer config.
-            teacher_config: Teacher's MCore transformer config.
+            model_config: MCore transformer config.
             temperature: Divide tensors by this value prior to calculating loss.
             reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher)
         """
-        super().__init__(student_config, teacher_config)
+        super().__init__(model_config)
         self._temperature = temperature
```
Align LogitsKLLoss with Tensor‑only contract.
Return a Tensor from forward; remove tuple propagation.
```diff
-        return self.post_forward(loss, tp_reduce=True)
+        return self.post_forward(loss)
```
If TP/SP reductions are needed, handle them in the balancer (central place), not by changing loss return types.
Also applies to: 293-364
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Plugin feature: Updated Megatron KD plugin module
Overview: ?
Usage
# Add a code snippet demonstrating how to use this
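A rough usage sketch based on the call flow in the review summary above; the exact signatures, the YAML path, and the surrounding Megatron-LM objects (student_cfg, teacher_cfg, distillation_model, lm_loss) are assumptions, not the verified API:

```python
from modelopt.torch.distill.plugins.megatron import (
    adjust_distillation_model_for_mcore,
    load_distillation_config,
)

# Build the structured config from a YAML file plus the student/teacher MCore configs (assumed inputs).
distill_cfg = load_distillation_config("distill.yaml", student_cfg, teacher_cfg)

# Patch the DistillationModel for Megatron-Core specifics (pipeline parallelism, teacher hiding, ...).
adjust_distillation_model_for_mcore(distillation_model, distill_cfg)

# During training, extra keyword arguments are forwarded to every per-layer loss function.
loss = distillation_model.compute_kd_loss(student_loss=lm_loss)
```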
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Documentation